import jax.nn as nn
import jax

from flax import linen as nn
from src.models.transformers import TruncatedLinearTransformer
from src.models.rnn import TruncatedVanillaRNN


class TruncatedLinearTransformerPredictor(nn.Module):
    """Prediction head stacked on a ``TruncatedLinearTransformer`` backbone.

    The backbone output is indexed with ``[-1]`` (its last entry along the
    leading axis). That entry is then mapped through a two-layer MLP:
    ``Dense(output_hidden_size) -> relu -> Dense(1)``.

    NOTE(review): if the backbone returns ``(seq, d_model)`` then ``[-1]``
    selects the final time step. If it returns a leading batch axis, it would
    instead select the last batch element. Confirm against the backbone.
    """

    # Hyperparameters forwarded unchanged to TruncatedLinearTransformer.
    d_model: int
    n_heads: int
    d_ffc: int
    n_layers: int
    # Number of units in the hidden Dense layer of the prediction head.
    output_hidden_size: int
    # Forwarded to the backbone; its exact meaning is defined there.
    truncation: int

    @nn.compact
    def __call__(self, inputs):
        backbone = TruncatedLinearTransformer(
            d_model=self.d_model,
            d_ffc=self.d_ffc,
            n_heads=self.n_heads,
            truncation=self.truncation,
            n_layers=self.n_layers,
        )
        features = backbone(inputs)
        final_entry = features[-1]

        # Layers are created in the same order as before (hidden, then
        # output) so Flax's automatic submodule names stay the same.
        hidden_layer = nn.Dense(self.output_hidden_size)
        output_layer = nn.Dense(1)
        head = nn.Sequential([hidden_layer, jax.nn.relu, output_layer])
        return head(final_entry)



class VanillaRNNPredictor(nn.Module):
    """Prediction head stacked on a ``TruncatedVanillaRNN`` backbone.

    The RNN output is passed directly, without any indexing, through a
    two-layer MLP: ``Dense(output_hidden_size) -> relu -> Dense(1)``.

    NOTE(review): unlike ``TruncatedLinearTransformerPredictor``, no ``[-1]``
    selection is applied here. Presumably the RNN already returns a single
    (final) state. Confirm against ``TruncatedVanillaRNN``.
    """

    # Forwarded unchanged to TruncatedVanillaRNN.
    d_model: int
    # Number of units in the hidden Dense layer of the prediction head.
    output_hidden_size: int
    # Forwarded to TruncatedVanillaRNN; its exact meaning is defined there.
    truncation: int

    @nn.compact
    def __call__(self, inputs):
        rnn = TruncatedVanillaRNN(d_model=self.d_model, truncation=self.truncation)
        rnn_features = rnn(inputs)

        # Same creation order as before (hidden, then output) so Flax's
        # automatic submodule names stay the same.
        hidden_layer = nn.Dense(self.output_hidden_size)
        output_layer = nn.Dense(1)
        head = nn.Sequential([hidden_layer, jax.nn.relu, output_layer])
        return head(rnn_features)